{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 解读 {func}`~tvm.relay.transform.InferTypeLocal`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import relay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m \u001b[0mrelay\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mInferTypeLocal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mSource:\u001b[0m   \n",
      "\u001b[0;32mdef\u001b[0m \u001b[0mInferTypeLocal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;34m\"\"\"Infer the type of a single expr, reusing type information to do so.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    This populates the checked_type field in expr. We assume existing type information\u001b[0m\n",
      "\u001b[0;34m    in the graph is correct!\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    Parameters\u001b[0m\n",
      "\u001b[0;34m    ----------\u001b[0m\n",
      "\u001b[0;34m    expr: relay.Expr\u001b[0m\n",
      "\u001b[0;34m        The expression we want to know the type of\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    Returns\u001b[0m\n",
      "\u001b[0;34m    -------\u001b[0m\n",
      "\u001b[0;34m    type: relay.Type\u001b[0m\n",
      "\u001b[0;34m        The type of the expression\u001b[0m\n",
      "\u001b[0;34m    \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;32mreturn\u001b[0m \u001b[0m_ffi_api\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mInferTypeLocal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFile:\u001b[0m      /media/pc/data/lxw/ai/tvm/python/tvm/relay/transform/transform.py\n",
      "\u001b[0;31mType:\u001b[0m      function"
     ]
    }
   ],
   "source": [
    "relay.transform.InferTypeLocal??"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "{func}`~tvm.relay.transform.InferTypeLocal` 函数的作用是推断单个 `expr` 的类型，并重用类型信息来实现这一点。它会填充表达式中的 `checked_type` 字段。我们假设计算图中现有的类型信息是正确的！\n",
    "\n",
    "参数：\n",
    "- `expr`: `relay.Expr`，我们想要知道其类型的表达式\n",
    "\n",
    "返回值：\n",
    "- `type`: `relay.Type`，表达式的类型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```c++\n",
    "/*!\n",
    " * \\brief Infer the type of an expression, reusing existing type information.\n",
    " *\n",
    " * The result of type checking is a new expression with unambiguous\n",
    " * type information filled in for the given node only. The local\n",
    " * version can use existing type information populated throughout\n",
    " * the expression and assumes this information is correct. The local\n",
    " * version also avoids examining large amounts of the graph assuming\n",
    " * type information is filled in properly which makes it much faster if we\n",
    " * iteratively call type inference.\n",
    " *\n",
    " * \\return The type of the expression.\n",
    " */\n",
    "TVM_DLL Type InferTypeLocal(const Expr& expr);\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这个函数的作用是推断表达式的类型，并重用现有的类型信息。\n",
    "\n",
    "类型检查的结果是一个新的表达式，其中给定节点的不明确的类型信息被填充。局部版本可以使用整个表达式中填充的现有类型信息，并假设这些信息是正确的。局部版本还避免了检查大量的计算图，假设类型信息被正确填充，如果我们迭代地调用类型推断，这会使其更快。\n",
    "\n",
    "返回值：表达式的类型。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```c++\n",
    "Type InferTypeLocal(const Expr& expr) {\n",
    "  /*\n",
    "  This type inference differs from InferType in that it uses existing type information\n",
    "  to avoid recursing over much of the graph, and it only examines the type of the input\n",
    "  node. This makes it faster if you need to run type inference iteratively throughout\n",
    "  a pass for example.\n",
    "\n",
    "  However, it assumes any existing populated type inference is correct! If some populated\n",
    "  type inference is incorrect, an incorrect type may be returned or a type error will be\n",
    "  raised. If you know not all populated type fields are correct with the current graph,\n",
    "  you should use InferType() instead.\n",
    "  */\n",
    "  SameTypedSubgraphExtractor subgraph_extractor;\n",
    "  Expr sub_graph = subgraph_extractor(expr);\n",
    "\n",
    "  Type result_type;\n",
    "  result_type = relay::InferType(sub_graph)->checked_type();\n",
    "\n",
    "  expr->checked_type_ = result_type;\n",
    "  return result_type;\n",
    "}\n",
    "\n",
    "TVM_REGISTER_GLOBAL(\"relay._transform.InferTypeLocal\").set_body_typed([](const Expr& expr) {\n",
    "  return InferTypeLocal(expr);\n",
    "});\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "与 `InferType` 不同的是，它使用现有的类型信息来避免在计算图中递归遍历很多部分，并且只检查输入节点的类型。如果您需要在传递过程中迭代运行类型推断，这将使其更快。\n",
    "\n",
    "但是，它假设任何现有填充的类型推断都是正确的！如果某些填充的类型推断是错误的，可能会返回错误类型的结果或引发类型错误。如果您知道当前计算图中并非所有填充的类型字段都是正确的，则应使用 `InferType()` 代替。\n",
    "\n",
    "该函数首先创建 `SameTypedSubgraphExtractor` 对象，然后使用该对象从给定的表达式中提取子图。接下来，它调用 `relay::InferType()` 函数来推断子图的类型，并将结果存储在 `result_type` 变量中。最后，它将 `result_type` 分配给 `expr` 对象的 `checked_type_` 属性，并将其作为函数的返回值返回。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312x",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
